removing dependency on setting the Keras backend #79
Conversation
@@ -103,7 +106,7 @@ def __init__(
        self.distribution_class = prior.__class__
        self.encoder = encoder
        self.compositor = (
-           compositor if compositor is not None else tf.keras.layers.Concatenate(axis=-1)
+           compositor if compositor is not None else tf.keras.layers.Concatenate(axis=-1, dtype=default_float())
This layer needed to be passed an explicit dtype as well, since it no longer inherits float64 from a globally set Keras backend.
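A minimal sketch of the pattern (assuming only gpflow's default_float() helper, which returns the configured default dtype, float64 by default):

import tensorflow as tf
from gpflow.config import default_float

# Without tf.keras.backend.set_floatx("float64"), Keras layers default to
# float32 and would silently downcast their inputs, so the dtype is passed
# explicitly at construction time instead.
concat = tf.keras.layers.Concatenate(axis=-1, dtype=default_float())

a = tf.constant([[1.0]], dtype=tf.float64)
b = tf.constant([[2.0]], dtype=tf.float64)
assert concat([a, b]).dtype == tf.float64  # no implicit cast to float32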
        (fit_adam, fit_adam, {}),
        (fit_adam, keras_fit_adam, {}),
        (fit_natgrad, fit_natgrad, {}),
        (fit_natgrad, keras_fit_natgrad, dict(atol=1e-7)),
This test required loosening the tolerance slightly (atol=1e-7).
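For context, the tol_kw dict from the parametrization is presumably unpacked into the final closeness assertion, roughly like this (a sketch of the assumed test structure, not the exact test code):

import numpy as np

def assert_equivalent(svgp_loss: float, keras_loss: float, tol_kw: dict) -> None:
    # An entry such as dict(atol=1e-7) relaxes the default tolerance for
    # parametrizations where the two optimizers agree only approximately.
    np.testing.assert_allclose(svgp_loss, keras_loss, **tol_kw)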
@pytest.mark.parametrize(
    "svgp_fitter, keras_fitter, tol_kw",
    [
        (fit_adam, _keras_fit_adam, {}),
-       (fit_natgrad, _keras_fit_natgrad, dict(atol=1e-8)),
+       (fit_natgrad, _keras_fit_natgrad, dict(atol=1e-6)),
This one as well. It seems there was already an issue with this test, so I assumed it might be OK. I'm not sure where the small differences actually arise, though; perhaps Keras is still using float32 somewhere, but I'm not sure how exactly to check that - any ideas?
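One possible way to check (a hedged sketch, not from the PR): walk the built Keras model and report any layer or variable that is still float32.

import tensorflow as tf

def report_float32(model: tf.keras.Model) -> None:
    # Diagnostic helper for tracking down silent downcasts; `model` is any
    # Keras model, e.g. the one built from a gpflux DeepGP.
    for layer in model.layers:
        if layer.dtype == "float32":
            print(f"layer {layer.name!r} has dtype float32")
    for var in model.variables:
        if var.dtype == tf.float32:
            print(f"variable {var.name!r} has dtype float32")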
tf.keras.backend.set_floatx("float64")
I didn't try to run this one, as it needed additional packages.
LGTM
This PR addresses #76. The fix itself was fairly straightforward; the rest was removing backend calls throughout the repo.

There was a tiny discrepancy in some tests in GPflux/tests/integration/test_svgp_equivalence.py; I couldn't pinpoint where it was coming from.

While this avoids setting the backend, one still needs to be careful that float64 data is given to the model (which is the case in all the tests at the moment). @vdutor, how do you want to go about this? Just warn the user in docstrings and some notebooks, or implement a data check in the DeepGP class (and perhaps somewhere else as well)?
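A minimal sketch of what such a check could look like (hypothetical: the helper name and its placement are assumptions, not part of this PR):

import tensorflow as tf
from gpflow.config import default_float

def _validate_dtype(data: tf.Tensor) -> None:
    # Hypothetical helper, e.g. called from DeepGP.__init__ or at the start
    # of training; raises instead of silently accepting float32 data.
    expected = default_float()
    if data.dtype != expected:
        raise ValueError(
            f"Expected input data of dtype {expected}, got {data.dtype}. "
            "Cast your data (e.g. X.astype(np.float64)) to avoid precision "
            "loss now that the Keras backend is no longer set globally."
        )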